# %%

import torch, torch.optim
import numpy as np
import scipy.io
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch import nn
from tqdm import tqdm
import torchvision.transforms.functional as TF
import torchvision.transforms as TVT
import PIL
import torchvision.datasets


# %%

# Precomputed rotation data (a MATLAB .mat file) plus the MNIST test split,
# resized on load so each digit matches the 14x14 grid used below.
mat = scipy.io.loadmat("rot_d12.mat")
mnist_data = torchvision.datasets.MNIST(
    root="",
    train=False,
    download=True,
    transform=TVT.Resize(size=14),
)

# %%

# Per-sample arrays, transposed so the sample axis comes first (presumably
# MATLAB saved one sample per column -- confirm against the data generator).
x, xp, ang = (mat[key].T for key in ("x", "xplus", "angles"))
# Not transposed: already row-aligned with `x` (concatenated along axis 1 below).
speed = mat["speedest"]
# Flattened pixel selector; presumably a boolean mask over the 14*14 grid.
remask = mat["remask"].flatten()

# Network input is [x | speed estimate]; the regression target is the
# change xplus - x. Both are moved to the GPU as float32.
X, Y = (
    torch.tensor(arr, dtype=torch.float32).cuda()
    for arr in (np.concatenate([x, speed], 1), xp - x)
)

x_dim, y_dim = X.shape[1], Y.shape[1]

# Inputs and targets stored side by side, so one shuffled DataLoader keeps
# each (input, target) pair together.
data = torch.cat((X, Y), dim=1)

# %%
batch_size = 512
lr = 0.0004
epochs = 10

# Two-hidden-layer MLP regressing the target Y (change xplus - x) from X.
net = nn.Sequential(
    nn.Linear(x_dim, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, y_dim),
).cuda()

optim = torch.optim.Adam(net.parameters(), lr=lr)

loss_hist = []


# Each row of `data` is [input | target]; shuffling whole rows keeps pairs aligned.
train_loader = DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True)
loss_func = nn.MSELoss()
try:
    pbar = tqdm(range(epochs))
    for e in pbar:
        for batch in train_loader:
            x_ = batch[:, :x_dim]
            y_ = batch[:, x_dim:]

            loss = loss_func(net(x_), y_)

            optim.zero_grad()
            loss.backward()
            optim.step()

            # Read the scalar once: each .item() forces a device->host sync.
            loss_val = loss.item()
            loss_hist.append(loss_val)
            pbar.set_description("Loss: {:.3g}".format(loss_val))

except KeyboardInterrupt:
    # Deliberate: Ctrl-C ends training early but keeps the partially trained
    # `net` and `loss_hist` for the plotting below.
    pass

# Reference level: MSE of always predicting zero change. It must be a Python
# float -- matplotlib cannot convert a CUDA tensor to a numpy array.
with torch.no_grad():
    zero_pred_mse = loss_func(Y, torch.zeros_like(Y)).item()
plt.axhline(zero_pred_mse)
plt.semilogy(loss_hist)


# %%

def _masked_digit(mnist_idx):
    """Return MNIST test digit `mnist_idx` as a 14x14 float64 array.

    Intensities come from `ToTensor` (range [0, 1]) and are then halved;
    every pixel not selected by `remask` is set to zero.
    """
    img = TVT.ToTensor()(mnist_data[mnist_idx][0])[0].numpy() / 2
    masked = np.zeros(14 ** 2)
    masked[remask] = img.flatten()[remask]
    return masked.reshape(14, 14)


def plot_rotation_figure(mnist_samp, out_path, samp=0, vmax=0.4, crop_first_col=False):
    """Plot network predictions against true rotations for one MNIST digit.

    Top row: the masked digit advanced one step by `net`, i.e.
    ``pixels + net([pixels, s * speed])`` where ``speed`` is the speed part of
    ``X[samp]`` and ``s`` runs over ``linspace(0, 1.2, 6)``.
    Bottom row: the same digit rotated by ``-s * ang[samp, 0]`` degrees with
    bilinear interpolation.

    Args:
        mnist_samp: index into the MNIST test set (`mnist_data`).
        out_path: file the figure is saved to.
        samp: row of `X` / `ang` whose speed estimate and angle are scaled.
        vmax: upper colour limit for every panel (the lower limit is -0.2).
        crop_first_col: drop the leftmost pixel column of the top-row panels.
    """
    scales = np.linspace(0, 1.2, 6)
    im = _masked_digit(mnist_samp)
    pix = torch.tensor(im.flatten()[remask], dtype=torch.float32).cuda()
    # The network input is [masked pixels | speed estimate]; derive the split
    # point from the actual pixel count rather than hard-coding it.
    n_pix = pix.shape[0]
    speed_vec = X[samp, n_pix:]

    fig, ax = plt.subplots(2, 6, figsize=(8, 3))
    ax = ax.flatten()

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        for j, s in enumerate(scales):
            inp = torch.cat([pix, s * speed_vec], 0)
            # The network predicts the change (trained on xplus - x), so add
            # it back onto the input pixels.
            pred = (inp[:n_pix] + net(inp)).cpu().numpy()
            canvas = np.zeros(14 ** 2)
            canvas[remask] = pred
            panel = canvas.reshape(14, 14)
            if crop_first_col:
                panel = panel[:, 1:]
            ax[j].imshow(panel, vmin=-0.2, vmax=vmax, cmap="Greys")
            ax[j].axis("off")

    # Ground truth: rotate the same masked digit directly.
    im_pil = TVT.ToPILImage()(torch.tensor(im, dtype=torch.float32))
    for j, s in enumerate(scales):
        # NOTE(review): newer torchvision releases removed the deprecated
        # `resample` keyword; there use interpolation=TF.InterpolationMode.BILINEAR.
        rotated = TF.rotate(im_pil, -s * ang[samp, 0], resample=PIL.Image.BILINEAR)
        ax[j + 6].imshow(TVT.ToTensor()(rotated)[0], vmin=-0.2, vmax=vmax, cmap="Greys")
        ax[j + 6].axis("off")

    fig.tight_layout()
    fig.savefig(out_path)


### 7 FIG!
# NOTE(review): only this figure crops the first pixel column, and only in the
# top row -- confirm that is intended.
plot_rotation_figure(mnist_samp=0, out_path="7_rot.pdf", crop_first_col=True)
# %%

### 2 FIG!

plot_rotation_figure(mnist_samp=1, out_path="2_rot.pdf")
# %%

### 1 FIG!

plot_rotation_figure(mnist_samp=2, out_path="1_rot.pdf")
# %%
### 3 FIG!

# Wider colour range for this digit.
plot_rotation_figure(mnist_samp=200, out_path="3_rot.pdf", vmax=0.7)
# %%

